#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Module: Process beneficiary claims data. Builds per-beneficiary mappings to ICD
diagnosis/procedure, DRG and outpatient codes for each year, and assembles the
multi-year sparse feature matrices used to build beneficiary profiles over years.
"""
__author__ = "shubhranshu-shekhar"
__date__ = "07/12/22"

import time
import bz2
import pickle
import json
from collections import Counter, defaultdict
import numpy as np
import scipy.sparse as sp
from functools import reduce
import pandas as pd


# importing already created module
try:
    from data_util import parse_args, mkdir_p, NumpyEncoder, read_data
except ImportError as e:
    from .data_util import parse_args, mkdir_p, NumpyEncoder, read_data


def get_beneficiary_age(args, min_age=70):
    """Return ``{bene_id: age}`` for beneficiaries at least ``min_age`` years old.

    The beneficiary summary file path is ``args.bsf`` with ``<PCT>`` and
    ``<year>`` replaced by ``args.pct`` and ``args.year``. Only the
    ``bene_id`` and ``age`` columns are read from it.

    Args:
        args: namespace with ``bsf``, ``pct`` and ``year`` attributes.
        min_age: inclusive lower bound on ``age`` (default 70, the original
            hard-coded restriction).

    Returns:
        dict mapping ``bene_id`` to ``age`` for the retained rows.
    """
    path = args.bsf.replace('<PCT>', args.pct).replace('<year>', args.year)
    print(path)
    # Only two columns are used; skip loading the rest of the file.
    bsf = pd.read_stata(path, columns=['bene_id', 'age'])
    bsf = bsf[bsf['age'] >= min_age]  # restrict age to min_age or older
    return dict(zip(bsf['bene_id'], bsf['age']))


def provider_beneficiary_cost(df, beneficiary_age_dct):
    """Group claim amounts by provider and beneficiary.

    Args:
        df: claims frame with ``provider``, ``bene_id``, ``base_clm_amt`` and
            ``pmt_amt`` columns. It is not modified.
        beneficiary_age_dct: accepted for interface compatibility. It is not
            used to filter rows; every beneficiary in ``df`` is included.

    Returns:
        ``(result_base, result_amt)``: nested defaultdicts
        ``{provider: {bene_id: [amount, ...]}}`` holding ``base_clm_amt`` and
        ``pmt_amt`` respectively, one entry per row in ``df`` order.
    """
    bene_id = "bene_id"

    result_base = defaultdict(lambda: defaultdict(list))
    result_amt = defaultdict(lambda: defaultdict(list))
    # Column-wise zip avoids building a Series per row (iterrows) and does not
    # write a scratch 'age' column into the caller's frame.
    for provider, bene, base_amt, pmt_amt in zip(df['provider'], df[bene_id],
                                                 df['base_clm_amt'], df['pmt_amt']):
        result_base[provider][bene].append(base_amt)
        result_amt[provider][bene].append(pmt_amt)

    return result_base, result_amt


def beneficiary_icd_drg_mapping(df, beneficiary_age_dct):
    """Map each beneficiary to the DRG and ICD codes of their claims.

    Only beneficiaries that are keys of ``beneficiary_age_dct`` (with a
    non-missing value) are kept. For each kept row, the values of ``drg_cd``,
    the ``icd_dgns_cd1..10`` and the ``icd_prcdr_cd1..6`` columns that exist in
    ``df`` are appended in that column order. Empty or missing codes are not
    filtered out here.

    Args:
        df: claims frame with ``bene_id`` and ``drg_cd`` columns. It is not
            modified.
        beneficiary_age_dct: ``{bene_id: age}`` used as the beneficiary filter.

    Returns:
        ``defaultdict(list)`` mapping ``bene_id`` to the concatenated codes of
        its rows, in ``df`` row order.
    """
    # book keeping
    drg_col = "drg_cd"
    bene_id = "bene_id"

    icd9_cols = ['icd_dgns_cd1', 'icd_dgns_cd2', 'icd_dgns_cd3', 'icd_dgns_cd4', 'icd_dgns_cd5',
                 'icd_dgns_cd6', 'icd_dgns_cd7', 'icd_dgns_cd8', 'icd_dgns_cd9', 'icd_dgns_cd10']
    icd9_prcdrs = ['icd_prcdr_cd1', 'icd_prcdr_cd2', 'icd_prcdr_cd3', 'icd_prcdr_cd4', 'icd_prcdr_cd5',
                   'icd_prcdr_cd6']

    # not every year's file carries every code column
    icd9_cols = [c for c in icd9_cols if c in df.columns]
    icd9_prcdrs = [c for c in icd9_prcdrs if c in df.columns]

    useful_columns = [drg_col] + icd9_cols + icd9_prcdrs

    # Positional boolean mask: filters without adding an 'age' column to the
    # caller's frame and without index alignment.
    has_age = df[bene_id].map(beneficiary_age_dct).notna().to_numpy()
    kept = df.loc[has_age, [bene_id] + useful_columns]

    result = defaultdict(list)
    codes = kept[useful_columns].to_numpy(dtype=object)
    for bene, row_codes in zip(kept[bene_id], codes):
        result[bene].extend(row_codes)

    return result


def beneficiary_op_cpt_mapping(args, beneficiary_age_dct):
    """Map each beneficiary to the outpatient ICD diagnosis/procedure codes of one year.

    The outpatient file path is ``args.op`` with ``<year>`` and ``<PCT>``
    replaced by ``args.year`` and ``args.pct``. Only ``bene_id`` and the 25
    diagnosis and 25 procedure code columns are read.

    Args:
        args: namespace with ``op``, ``year`` and ``pct`` attributes.
        beneficiary_age_dct: ``{bene_id: age}``. Rows whose beneficiary is not a
            key (or maps to a missing value) are dropped.

    Returns:
        ``{bene_id: [code, ...]}`` with keys in sorted order. Codes are
        concatenated across the beneficiary's rows in file order, each row read
        left to right. Falsy entries (e.g. ``''``, ``None``) are dropped; NaN
        is truthy and therefore kept.
    """
    # book keeping
    bene_id = "bene_id"

    icd9_cols = ['icd_dgns_cd{}'.format(i) for i in range(1, 26)]
    icd9_prcdrs = ['icd_prcdr_cd{}'.format(i) for i in range(1, 26)]

    useful_columns = icd9_cols + icd9_prcdrs
    select_cols = [bene_id] + useful_columns

    # read outpatient data
    input_path = args.op.replace('<year>', args.year).replace('<PCT>', args.pct)
    print('Reading OP data from:', input_path)
    df = pd.read_stata(input_path, columns=select_cols)

    # keep only beneficiaries present in the age mapping
    has_age = df[bene_id].map(beneficiary_age_dct).notna().to_numpy()
    df = df[has_age]
    print('Number of rows in the outpatient data: ', df.shape[0])

    # Accumulate with list.extend; a groupby "sum" over list cells concatenates
    # lists pairwise, which is quadratic in the number of rows per beneficiary.
    codes_by_bene = defaultdict(list)
    for bene, row_codes in zip(df[bene_id], df[useful_columns].values.tolist()):
        codes_by_bene[bene].extend(code for code in row_codes if code)

    # Sorted keys, matching the key order a sorted groupby would produce.
    return {bene: codes_by_bene[bene] for bene in sorted(codes_by_bene)}


def create_aggregate_mat(mat_yr_cat, col_map, years=('2017', '2016', '2015', '2014', '2013', '2012')):
    """Sum per-year column blocks into one matrix over a global column space.

    ``mat_yr_cat`` holds one block of columns per year, placed side by side in
    the key order of ``col_map``. ``col_map[yr]`` maps a feature key to its local
    column index inside that year's block (expected to be ``0..len-1``). Each
    block whose year is listed in ``years`` is re-indexed onto the union of all
    years' keys, and the blocks are added element-wise.

    Args:
        mat_yr_cat: sparse matrix (anything with ``.tocsr()``) of year blocks.
        col_map: ordered ``{year: {key: local_col}}`` describing the blocks.
        years: years whose blocks are included in the sum.

    Returns:
        ``(mat, global_col_map)``: a CSR matrix of shape
        ``(n_rows, n_global_keys)`` and ``{key: global_col}``. Keys of every
        year in ``col_map`` get a global column, including years not summed.
    """
    mat_yr_cat = mat_yr_cat.tocsr()
    n_rows = mat_yr_cat.shape[0]
    selected_years = set(years)

    # Global column space: union of all years' keys in first-seen order.
    # (Iterating a set of str keys is not reproducible across processes.)
    global_col_keys = list(dict.fromkeys(key for yr_map in col_map.values() for key in yr_map))
    global_col_map = {key: idx for idx, key in enumerate(global_col_keys)}
    n_cols_global = len(global_col_map)

    mat = sp.csr_matrix((n_rows, n_cols_global), dtype=mat_yr_cat.dtype)
    offset = 0  # start column of the current year's block
    for yr, yr_col_map in col_map.items():
        n_cols = len(yr_col_map)
        start, offset = offset, offset + n_cols  # blocks may differ in width
        if yr not in selected_years:
            continue
        block = mat_yr_cat[:, start:start + n_cols].tocoo()
        local_to_key = {v: k for k, v in yr_col_map.items()}
        # Remap each distinct local column once, then broadcast to the nonzeros.
        uniq_cols, inverse = np.unique(block.col, return_inverse=True)
        mapped = np.array([global_col_map[local_to_key[j]] for j in uniq_cols], dtype=np.int64)
        remapped = sp.csr_matrix((block.data, (block.row, mapped[inverse].ravel())),
                                 shape=(n_rows, n_cols_global))
        mat = mat + remapped

    return mat, global_col_map


def create_sum_total_except_current_codes():
    """Rebuild the summed regression matrix without the 2017 ICD/DRG features.

    Reads the year-wise ICD matrix from ``../data/bene_icd_mat_all.bz2.pkl``
    and drops its 2017 column block. The remaining years 2016-2013 are summed
    with ``create_aggregate_mat``. The result is then stacked with the car,
    OP, chronic and beneficiary-provider matrices (loaded from the pickles
    written by ``create_sum_total_data_matrix`` and elsewhere), and reuses the
    saved target.

    Writes:
        ``../data/icd_mat_all_sum_no_curr_codes.bz2.pkl``: ``[X_icd, col_map]``
        ``../data/sum_mat_target_no_curr_codes.bz2.pkl``: ``[X, target]``

    Raises:
        ValueError: if the ICD column maps have no ``'2017'`` entry.

    NOTE: all inputs are unpickled; only run this on trusted, locally
    produced files.
    """
    # load files created by sum total
    print('Loading year-wise ICD files')
    with bz2.BZ2File('../data/bene_icd_mat_all.bz2.pkl', 'rb') as f:
        [X_icd, bene_icd_row_maps, bene_icd_col_maps] = pickle.load(f)

    print("Removing ICD mat for year 2017...")
    X_icd = X_icd.tocsr()
    # Year blocks sit side by side in col-map key order (the same layout
    # create_aggregate_mat slices). Locate the 2017 block by its offset
    # rather than assuming it is the first block.
    icd_years = list(bene_icd_col_maps)
    start = sum(len(bene_icd_col_maps[yr]) for yr in icd_years[:icd_years.index('2017')])
    stop = start + len(bene_icd_col_maps['2017'])
    if start == 0:
        X_icd = X_icd[:, stop:]
    else:
        X_icd = sp.hstack([X_icd[:, :start], X_icd[:, stop:]]).tocsr()

    bene_icd_col_maps.pop('2017')  # update column map to match the sliced matrix
    print("Updated ICD col years:", list(bene_icd_col_maps.keys()))

    print("Combining ICD mat...")
    X_icd_updated, global_icd_col_map_updated = create_aggregate_mat(X_icd, bene_icd_col_maps,
                                                                     years=('2016', '2015', '2014', '2013'))

    with bz2.BZ2File('../data/icd_mat_all_sum_no_curr_codes.bz2.pkl', 'wb') as f:
        pickle.dump([X_icd_updated, global_icd_col_map_updated], f)
    print("Done.")

    print("Beginning to create large matrix without the current year DRG and ICD codes...")

    print('Loading car files')
    with bz2.BZ2File('../data/car_mat_all_sum.bz2.pkl', 'rb') as f:
        [X_car, global_car_col_map] = pickle.load(f)
    print("Done.")

    print('Loading OP CPT files')
    with bz2.BZ2File('../data/op_mat_all_sum.bz2.pkl', 'rb') as f:
        [X_op, global_op_col_map] = pickle.load(f)
    print("Done.")

    print('Loading Chronic files')
    with bz2.BZ2File('../data/bene_chronic_mat.bz2.pkl', 'rb') as f:
        [mat_bene_chronic, master_bene_ids_lst] = pickle.load(f)
    print("Done.")

    print('Beneficiary provider files')
    with bz2.BZ2File('../data/bene_prvdr_mat.bz2.pkl', 'rb') as f:
        [mat_bene_prvdr, bene_prvdr_row_map, bene_prvdr_col_map] = pickle.load(f)

    with bz2.BZ2File('../data/sum_mat_target.bz2.pkl', 'rb') as f:
        [_, target] = pickle.load(f)  # only the target is reused

    print("Saving sum total matrix without the DRG and ICD code from year 2017...")
    X = sp.hstack([X_car, X_icd_updated, X_op, mat_bene_chronic, mat_bene_prvdr])
    X = X.tocsr()

    with bz2.BZ2File('../data/sum_mat_target_no_curr_codes.bz2.pkl', 'wb') as f:
        pickle.dump([X, target], f)

    print('Done')


def create_sum_total_data_matrix():
    """Sum each feature family over years and build the regression target.

    Loads the year-wise car, ICD and OP matrices and the chronic-condition
    matrix from ``../data``. It checks that they share the beneficiary row
    order, sums each family over years with ``create_aggregate_mat``, and saves
    each sum. It then builds a beneficiary x provider count matrix and the
    per-beneficiary total base claim amount from
    ``../output/2017/provider_beneficiary_base.json``.

    Writes ``car_mat_all_sum``, ``icd_mat_all_sum``, ``op_mat_all_sum``,
    ``bene_prvdr_mat`` and ``sum_mat_target`` (``[X, target]``) pickles under
    ``../data``.

    Raises:
        ValueError: if the beneficiary row order of the chronic, OP (2017) and
            car (2013) matrices differs.

    NOTE: inputs are unpickled; only run this on trusted, locally produced files.
    """
    print('Loading car files')
    with bz2.BZ2File('../data/bene_car_mat_all.bz2.pkl', 'rb') as f:
        [X_car, bene_car_row_maps, bene_car_col_maps] = pickle.load(f)

    print('Loading ICD files')
    with bz2.BZ2File('../data/bene_icd_mat_all.bz2.pkl', 'rb') as f:
        [X_icd, bene_icd_row_maps, bene_icd_col_maps] = pickle.load(f)

    print('Loading OP CPT files')
    with bz2.BZ2File('../data/bene_op_mat_all.bz2.pkl', 'rb') as f:
        [X_op, bene_op_row_maps, bene_op_col_maps] = pickle.load(f)

    print('Loading Chronic files')
    with bz2.BZ2File('../data/bene_chronic_mat.bz2.pkl', 'rb') as f:
        [mat_bene_chronic, master_bene_ids_lst] = pickle.load(f)

    print("Assert that row and columns are straight")
    # Explicit check (not ``assert``) so it still runs under ``python -O``:
    # every matrix stacked below must share this row order.
    if not (master_bene_ids_lst == list(bene_op_row_maps['2017'].keys())
            == list(bene_car_row_maps['2013'].keys())):
        raise ValueError("Beneficiary row order differs between the chronic, "
                         "OP (2017) and car (2013) matrices")
    master_bene_ids_set = set(master_bene_ids_lst)

    print("Combining car mat...")
    X_car, global_car_col_map = create_aggregate_mat(X_car, bene_car_col_maps)
    with bz2.BZ2File('../data/car_mat_all_sum.bz2.pkl', 'wb') as f:
        pickle.dump([X_car, global_car_col_map], f)
    print("Done.")

    print("Combining ICD mat...")
    X_icd, global_icd_col_map = create_aggregate_mat(X_icd, bene_icd_col_maps)
    with bz2.BZ2File('../data/icd_mat_all_sum.bz2.pkl', 'wb') as f:
        pickle.dump([X_icd, global_icd_col_map], f)
    print("Done.")

    print("Combining OP mat...")
    X_op, global_op_col_map = create_aggregate_mat(X_op, bene_op_col_maps)
    with bz2.BZ2File('../data/op_mat_all_sum.bz2.pkl', 'wb') as f:
        pickle.dump([X_op, global_op_col_map], f)
    print("Done.")

    print('Processing target to have one-hot encoded provider and total amount spent')
    with open('../output/2017/provider_beneficiary_base.json') as f:
        prvdr_bene_base = json.load(f)
    bene_prvdr_dct = defaultdict(dict)
    bene_amount = defaultdict(float)
    for pr, b_dct in prvdr_bene_base.items():
        for b_, amt in b_dct.items():
            if b_ in master_bene_ids_set:
                bene_prvdr_dct[b_][pr] = len(amt)  # number of claims with this provider
                bene_amount[b_] += sum(amt)

    # Beneficiary x provider claim-count matrix, rows in master_bene_ids_lst order.
    bene_prvdr_row_map = {b_: i for i, b_ in enumerate(master_bene_ids_lst)}
    # First-seen provider order: deterministic and safe when no provider matched.
    bene_prvdr_col_keys = list(dict.fromkeys(p_ for p_dct in bene_prvdr_dct.values() for p_ in p_dct))
    bene_prvdr_col_map = {p_: j for j, p_ in enumerate(bene_prvdr_col_keys)}

    rows, cols, vals = [], [], []
    for b_, p_dct in bene_prvdr_dct.items():
        for p_, cnt_ in p_dct.items():
            rows.append(bene_prvdr_row_map[b_])
            cols.append(bene_prvdr_col_map[p_])
            vals.append(cnt_)
    # Explicit shape: without it, beneficiaries with no provider claims at the
    # end of the master list would be dropped and the hstack below would fail.
    mat_bene_prvdr = sp.csr_matrix((vals, (rows, cols)),
                                   shape=(len(master_bene_ids_lst), len(bene_prvdr_col_map)))
    print("Bene provider matrix", mat_bene_prvdr.shape)
    with bz2.BZ2File('../data/bene_prvdr_mat.bz2.pkl', 'wb') as f:
        pickle.dump([mat_bene_prvdr, bene_prvdr_row_map, bene_prvdr_col_map], f)

    X = sp.hstack([X_car, X_icd, X_op, mat_bene_chronic, mat_bene_prvdr])
    X = X.tocsr()

    target = [bene_amount[k] for k in master_bene_ids_lst]
    with bz2.BZ2File('../data/sum_mat_target.bz2.pkl', 'wb') as f:
        pickle.dump([X, target], f)  # includes all years


def _write_json(obj, path):
    """Serialise ``obj`` to ``path`` as JSON with ``NumpyEncoder``, closing the file."""
    with open(path, 'w') as f:
        json.dump(obj, f, cls=NumpyEncoder)


def main():
    """Run the per-year extraction pipeline, then build the regression matrix.

    1. Read the 2017 beneficiary summary file and save the ``bene_id -> age``
       mapping (as returned by ``get_beneficiary_age``) to
       ``<savepath>/2017/beneficiary_age.json``.
    2. For each year 2017..2012, save the ICD/DRG and outpatient code mappings
       of those beneficiaries. For 2017, also save per-provider base and
       payment amounts.
    3. Call ``create_sum_total_except_current_codes``.
    """
    args = parse_args()
    args.savepath = '../output/'

    PCT = '100pct'

    # process 2017 data first to get beneficiary age
    yr = '2017'
    start = time.time()  # to see how much time each year requires
    save_path = args.savepath + yr + '/'
    mkdir_p(save_path)

    # read data per year
    args.year = yr
    args.pct = PCT
    print("Reading data...")
    bsf_age_dct = get_beneficiary_age(args)
    _write_json(bsf_age_dct, save_path + "beneficiary_age.json")

    end = time.time()
    print('Elapsed time for year {} is {} seconds.'.format(yr, end - start))

    # Load 70 years or older beneficiaries back from disk; the JSON round-trip
    # yields string keys and plain Python values.
    with open(args.savepath + '2017' + '/' + "beneficiary_age.json") as f:
        beneficiary_age_dct = json.load(f)
    print('Loaded beneficiary age mapping.')

    for yr in ['2017', '2016', '2015', '2014', '2013', '2012']:
        start = time.time()  # to see how much time each year requires
        save_path = args.savepath + yr + '/'
        mkdir_p(save_path)

        # read data per year
        args.year = yr
        args.pct = PCT

        df = read_data(args, yr)

        print('Computing beneficiary ICD and DRG')
        beneficiary_icd_drg_dstn = beneficiary_icd_drg_mapping(df, beneficiary_age_dct)
        _write_json(beneficiary_icd_drg_dstn, save_path + "beneficiary_icd_drg_mapping.json")

        # saving Outpatient codes per beneficiary
        bene_op_cpt_dct = beneficiary_op_cpt_mapping(args, beneficiary_age_dct)
        _write_json(bene_op_cpt_dct, save_path + "beneficiary_op_cpt_mapping.json")

        if yr == '2017':
            print('Saving regression targets...')
            # save claim and base claim amount per provider and beneficiary
            result_base, result_amt = provider_beneficiary_cost(df, beneficiary_age_dct)
            _write_json(result_base, save_path + "provider_beneficiary_base.json")
            _write_json(result_amt, save_path + "provider_beneficiary_pmt.json")

        end = time.time()
        print('Elapsed time for year {} is {} seconds.'.format(yr, end - start))

    # process regression data
    create_sum_total_except_current_codes()  # creating sum total without current year DRG and ICD counts


# Run the full pipeline only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
